{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "<a href=\"https://colab.research.google.com/github/lollcat/fab-torch/blob/master/demo/many_well.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "a3e5dcf4be772f9b"
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Install fab-torch repo"
   ],
   "metadata": {
    "id": "-2z7-wbmQgVS"
   },
   "id": "-2z7-wbmQgVS"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94b5026c-6d96-464c-894e-c3d2be6ec58b",
   "metadata": {
    "id": "94b5026c-6d96-464c-894e-c3d2be6ec58b"
   },
   "outputs": [],
   "source": [
    "# Colab only: clone fab-torch and install it into the kernel's environment.\n",
    "# Restart the runtime once this cell has finished, before running the cells below.\n",
    "!git clone https://github.com/lollcat/fab-torch\n",
    "\n",
    "import os\n",
    "os.chdir(\"fab-torch\")  # the install command below targets the current directory\n",
    "\n",
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install --upgrade ."
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Download weights from Hugging Face and run an inference example\n",
    "The model is inexpensive to evaluate, so running on the CPU is sufficient."
   ],
   "metadata": {
    "id": "xy2GWTB7QlxO"
   },
   "id": "xy2GWTB7QlxO"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d861b2f8-00be-4e14-998c-348ec89d1c89",
   "metadata": {
    "id": "d861b2f8-00be-4e14-998c-348ec89d1c89"
   },
   "outputs": [],
   "source": [
    "# Restart the runtime after the install cell has finished, then run the cells below.\n",
    "import os\n",
    "# `import urllib` alone does not guarantee that the `urllib.request` submodule is\n",
    "# loaded; import it explicitly because the weights are fetched with urlretrieve.\n",
    "import urllib.request\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import rc\n",
    "import matplotlib as mpl\n",
    "from hydra import compose, initialize\n",
    "import torch\n",
    "\n",
    "from fab.utils.plotting import plot_contours, plot_marginal_pair\n",
    "from fab.target_distributions.many_well import ManyWellEnergy\n",
    "from experiments.setup_run import setup_model\n",
    "from experiments.many_well.many_well_visualise_all_marginal_pairs import get_target_log_prob_marginal_pair"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66a8cf07-5d35-4368-a320-fc63e9842d7d",
   "metadata": {
    "id": "66a8cf07-5d35-4368-a320-fc63e9842d7d"
   },
   "outputs": [],
   "source": [
    "# Compose the many_well experiment config with Hydra.\n",
    "# NOTE(review): config_path assumes the working directory is the parent of the\n",
    "# cloned fab-torch folder (e.g. after a runtime restart) - confirm.\n",
    "with initialize(version_base=None, config_path=\"fab-torch/experiments/config/\", job_name=\"colab_app\"):\n",
    "    cfg = compose(config_name=\"many_well\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cc811fd-2208-4111-b428-dfcad4b0bf7c",
   "metadata": {
    "id": "8cc811fd-2208-4111-b428-dfcad4b0bf7c"
   },
   "outputs": [],
   "source": [
    "# Build the target energy with use_gpu=False, then construct the model from the config.\n",
    "# NOTE(review): a and b are fixed here rather than read from cfg - confirm they\n",
    "# match the values the pretrained weights were trained with.\n",
    "target = ManyWellEnergy(\n",
    "    cfg.target.dim,\n",
    "    a=-0.5,\n",
    "    b=-6,\n",
    "    use_gpu=False,\n",
    ")\n",
    "model = setup_model(cfg, target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c19b7e70-c0ba-49d0-8aa0-ed77001c8e95",
   "metadata": {
    "id": "c19b7e70-c0ba-49d0-8aa0-ed77001c8e95"
   },
   "outputs": [],
   "source": [
    "# Fetch the pretrained many-well weights from Hugging Face and load them onto the CPU.\n",
    "weights_url = 'https://huggingface.co/VincentStimper/fab/resolve/main/many_well/model.pt'\n",
    "weights_path = 'model.pt'\n",
    "urllib.request.urlretrieve(weights_url, weights_path)\n",
    "# NOTE(review): presumably model.load unpickles the checkpoint - only load files from a trusted source.\n",
    "model.load(weights_path, map_location=\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4077f8a-d5bc-45f5-b401-7b807c4e68ed",
   "metadata": {
    "id": "f4077f8a-d5bc-45f5-b401-7b807c4e68ed"
   },
   "outputs": [],
   "source": [
    "# Draw samples from the flow of the loaded model.\n",
    "n_samples: int = 200\n",
    "# Seed torch's global RNG so that re-running the notebook reproduces the same figure.\n",
    "# NOTE(review): assumes the flow samples via torch's global RNG - confirm.\n",
    "torch.manual_seed(0)\n",
    "# detach() drops the autograd graph; the samples are only used for plotting.\n",
    "samples_flow = model.flow.sample((n_samples,)).detach()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f338bd95-3ea2-4df7-a700-cb8529a26914",
   "metadata": {
    "id": "f338bd95-3ea2-4df7-a700-cb8529a26914"
   },
   "outputs": [],
   "source": [
    "# Visualise samples: overlay the flow samples on target contours in a 2x2 grid.\n",
    "alpha = 0.3\n",
    "plotting_bounds = (-3, 3)\n",
    "dim = cfg.target.dim\n",
    "fig, axs = plt.subplots(2, 2, sharex=\"row\", sharey=\"row\")\n",
    "\n",
    "# Panel (i, j) shows the pair of dimensions (i, j + 2).\n",
    "for i, row_axes in enumerate(axs):\n",
    "    for j, ax in enumerate(row_axes):\n",
    "        pair = (i, j + 2)\n",
    "        target_log_prob = get_target_log_prob_marginal_pair(target.log_prob, pair[0], pair[1], dim)\n",
    "        plot_contours(\n",
    "            target_log_prob,\n",
    "            bounds=plotting_bounds,\n",
    "            ax=ax,\n",
    "            n_contour_levels=20,\n",
    "            grid_width_n_points=100,\n",
    "        )\n",
    "        plot_marginal_pair(\n",
    "            samples_flow,\n",
    "            marginal_dims=pair,\n",
    "            ax=ax,\n",
    "            bounds=plotting_bounds,\n",
    "            alpha=alpha,\n",
    "        )\n",
    "\n",
    "        # Label only the outer edge of the grid: left column and bottom row.\n",
    "        # NOTE(review): the y-label uses the first index of the pair and the x-label\n",
    "        # the second - confirm this matches the axis order used by plot_marginal_pair.\n",
    "        if j == 0:\n",
    "            ax.set_ylabel(f\"$x_{i + 1}$\")\n",
    "        if i == 1:\n",
    "            ax.set_xlabel(f\"$x_{j + 1 + 2}$\")"
   ]
  },
  {
   "cell_type": "code",
   "source": [],
   "metadata": {
    "id": "MBK1xcr8UDui"
   },
   "id": "MBK1xcr8UDui",
   "execution_count": null,
   "outputs": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  },
  "colab": {
   "provenance": [],
   "include_colab_link": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
